{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 使用 qianfan sdk 构建本地评估模型\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "千帆大模型平台提供了在线进行自动评估的方式，而有时我们希望有更灵活的自动评估方式。\n",
    "\n",
    "本文将以评估文本摘要效果为例子来介绍\n",
    "\n",
    "1. 如何使用千帆 sdk 封装自定义本地评估方法\n",
    "2. 使用函数封装手搓其原理实现。"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "以下是本文封装部分使用的库与作用\n",
    "    Dataset：提供数据集接口，用于构造待评估数据集，方便在本地、平台、huggingface多端拉取数据集。\n",
    "    Prompt：提供数据模板，用于指导模型进行评估，可联网获取prompt或者使用prompt优化功能。\n",
    "    Service：提供模型服务接口，在evaluateManager中为没有生成答案的数据集提供生成能力。\n",
    "    EvaluateManager：提供评估任务管理功能，封装了异步调用算子，可以方便地实现多线程评估。"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "此处使用0.2.8版本以上的qianfan sdk。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -U \"qianfan>=0.2.8\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "进行鉴权认证。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-01-25T07:41:10.465514Z",
     "start_time": "2024-01-25T07:41:10.462826Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ['QIANFAN_ACCESS_KEY'] = 'your_access_key'\n",
    "os.environ['QIANFAN_SECRET_KEY'] = 'your_secret_key'\n",
    "\n",
    "# 使用 Service 对象进行评估时，请按实际情况填写 QPS LIMIT，\n",
    "# 取值为所有 Service QPS Limit中的最小值\n",
    "os.environ[\"QIANFAN_QPS_LIMIT\"] = \"1\"\n",
    "os.environ['QIANFAN_LLM_API_RETRY_COUNT'] = \"3\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 数据与模板准备\n",
    "\n",
    "首先构造待评估数据集，此处用本地数据集，也可以拉取千帆平台上的数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-01-25T07:41:11.163193Z",
     "start_time": "2024-01-25T07:41:10.466579Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'prompt': '新华社受权于18日全文播发修改后的《中华人民共和国立法法》，修改后的立法法分为“总则”“法律”“行政法规”“地方性法规、自治条例和单行条例、规章”“适用与备案审查”“附则”等6章，共计105条。', 'response': [['修改后的立法法全文公布']], '_group': 0}\n"
     ]
    }
   ],
   "source": [
    "from qianfan.dataset import Dataset\n",
    "\n",
    "ds = Dataset.load(data_file=\"data_summerize/excerpt.jsonl\", organize_data_as_group=True, input_columns=[\"prompt\"], reference_column=\"response\")\n",
    "\n",
    "print(ds[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "然后创建prompt模版，用于引导模型进行评估。\n",
    "\n",
    "由于摘要任务没有标准答案，所以此处模板不提供，如果是评估问答任务，则需要提供标准答案以评估正确率。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-01-25T07:41:11.163729Z",
     "start_time": "2024-01-25T07:41:11.161302Z"
    }
   },
   "outputs": [],
   "source": [
    "from qianfan.common import Prompt\n",
    "\n",
    "evaluation_prompt_template = \"\"\"\n",
    "你是一名裁判员，负责为给定新闻的摘要评分。\n",
    "\n",
    "评价标准：\n",
    "\n",
    "{criteria}\n",
    "\n",
    "请你遵照以下的评分步骤：\n",
    "{steps}\n",
    "\n",
    "\n",
    "例子：\n",
    "\n",
    "新闻：\n",
    "\n",
    "{prompt}\n",
    "\n",
    "摘要：\n",
    "\n",
    "{response}\n",
    "\n",
    "\n",
    "根据答案的综合水平给出0到{max_score}之间的整数评分。\n",
    "如果答案存在明显的不合理之处，则应给出一个较低的评分。\n",
    "如果答案符合以上要求并且与参考答案含义相似，则应给出一个较高的评分\n",
    "\n",
    "你的回答模版只能输出整数评分,不输出文字\n",
    "\"\"\"\n",
    "\n",
    "evaluation_prompt = Prompt(\n",
    "    name=\"evaluation_prompt\",\n",
    "    template=evaluation_prompt_template\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接着指定评估指标和评估步骤。\n",
    "\n",
    "此处从四个维度进行评估，分别是相关性、连贯性、一致性、流畅度"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "相关性"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-01-25T07:41:11.169702Z",
     "start_time": "2024-01-25T07:41:11.165530Z"
    }
   },
   "outputs": [],
   "source": [
    "relevance_metric = \"\"\"\n",
    "相关性 - 摘要中重要内容的选择。 \\ \n",
    "摘要只应包含来自源文档的重要信息。 \\\n",
    "惩罚包含冗余和多余信息的摘要。\n",
    "\"\"\"\n",
    "\n",
    "relevance_steps = \"\"\"\n",
    "1. 仔细阅读摘要和源文档。\n",
    "2. 将摘要与源文档进行比较，并识别文章的主要观点。\n",
    "3. 评估摘要是否涵盖了文章的主要观点，以及它包含多少无关或冗余信息。\n",
    "\"\"\"\n",
    "\n",
    "relevance_max_score = 10"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "连贯性"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-01-25T07:41:11.170052Z",
     "start_time": "2024-01-25T07:41:11.167092Z"
    }
   },
   "outputs": [],
   "source": [
    "\n",
    "coherence_metric = \"\"\"\n",
    "连贯性 - 所有句子的总体质量。 \\\n",
    "我们将此维度与 DUC 质量问题结构性和连贯性相关， \\\n",
    "其中“摘要应具有良好的结构和组织，不应只是相关信息的堆叠，而应基于从句子到主题有关的信息的连贯性主体进行构建。” \\\n",
    "\n",
    "\"\"\"\n",
    "\n",
    "coherence_steps = \"\"\"\n",
    "1. 仔细阅读文章并识别主要主题和关键点。\n",
    "2. 阅读摘要并与文章进行比较。检查摘要是否涵盖了文章的主要主题和关键点，并以清晰且逻辑的顺序呈现它们\n",
    "\"\"\"\n",
    "\n",
    "coherence_max_score = 10"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "一致性"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-01-25T07:41:11.176007Z",
     "start_time": "2024-01-25T07:41:11.170758Z"
    }
   },
   "outputs": [],
   "source": [
    "consistency_metric = \"\"\"\n",
    "一致性 - 摘要与摘要源的客观对齐。 \\\n",
    "客观一致的摘要只包含与源文档支持的陈述。 \\\n",
    "惩罚包含幻觉事实的摘要。\n",
    "\"\"\"\n",
    "\n",
    "consistency_steps = \"\"\"\n",
    "1. 仔细阅读文章并识别主要事实和细节。\n",
    "2. 阅读摘要并与文章进行比较。检查摘要是否包含任何不支持的文章的事实错误。\n",
    "3. 基于评估标准，为一致性分配分数。\n",
    "\"\"\"\n",
    "\n",
    "consistency_max_score = 10"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "流畅度"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-01-25T07:41:11.176392Z",
     "start_time": "2024-01-25T07:41:11.173566Z"
    }
   },
   "outputs": [],
   "source": [
    "fluency_metric = \"\"\"\n",
    "流畅度：摘要的语法、拼写、标点、单词选择和句子结构的质量。\n",
    "1：差。摘要有很多错误，使它难以理解或听起来不自然。\n",
    "2：中。摘要有一些影响文本清晰度或流畅性的错误，但要点仍然可理解。\n",
    "3：好。摘要很少或没有错误，易于阅读和遵循。\n",
    "\"\"\"\n",
    "\n",
    "fluency_steps = \"\"\"\n",
    "读取摘要并基于给定的标准评估其流畅度。从 1 到 3 分配一个流畅度分数。\n",
    "\"\"\"\n",
    "\n",
    "fluency_max_score = 3"
   ]
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "evaluation_metrics = {\n",
    "    \"Relevance\": (relevance_metric, relevance_metric, relevance_max_score),\n",
    "    \"Coherence\": (coherence_metric, coherence_steps, coherence_max_score),\n",
    "    \"Consistency\": (consistency_metric, consistency_steps, consistency_max_score),\n",
    "    \"Fluency\": (fluency_metric, fluency_steps, fluency_max_score),\n",
    "}"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-01-25T07:41:11.182006Z",
     "start_time": "2024-01-25T07:41:11.176493Z"
    }
   },
   "execution_count": 9
  },
  {
   "cell_type": "markdown",
   "source": [
    "通过千帆sdk，我们有两种方法实现"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 方法一：封装本地评估器\n",
    "\n",
    "使用千帆平台提供的本地评估工具类进行评估\n",
    "此方法封装了异步请求，可以根据数据量并发进行评估\n",
    "\n",
    "首先继承LocalEvaluator类，并实现evaluate方法"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "from typing import Any, Dict, List, Union\n",
    "from qianfan.utils.pydantic import Field\n",
    "from qianfan.evaluation.consts import (\n",
    "    QianfanRefereeEvaluatorDefaultMaxScore,\n",
    "    QianfanRefereeEvaluatorDefaultMetrics,\n",
    "    QianfanRefereeEvaluatorDefaultSteps,\n",
    ")\n",
    "from qianfan.evaluation.evaluator import LocalEvaluator\n",
    "\n",
    "class QianfanLocalEvaluator(LocalEvaluator):\n",
    "\n",
    "    model: Any\n",
    "    metric_name: str = Field(default=\"\")\n",
    "    prompt_metrics: str = Field(default=QianfanRefereeEvaluatorDefaultMetrics)\n",
    "    prompt_steps: str = Field(default=QianfanRefereeEvaluatorDefaultSteps)\n",
    "    prompt_max_score: int = Field(default=QianfanRefereeEvaluatorDefaultMaxScore)\n",
    "\n",
    "    # evaluate三个参数，input是输入，reference是模型输出的结果，output是参考答案。\n",
    "    # 当evaluationManager调用eval时，如果is_chat_service=True，则input为对话列表 List[Dict[str, Any]]，否则为文本 str。\n",
    "    # 在该任务中，input是新闻，output是摘要，reference是service生成的摘要\n",
    "    # 由于我们只是对新闻摘要进行评分，因此只需要用到input和output\n",
    "    def evaluate(\n",
    "            self, input: Union[str, List[Dict[str, Any]]], reference: str, output: str\n",
    "        ) -> Dict[str, Any]:\n",
    "        # 将input调整为合适格式\n",
    "        input_content = input if isinstance(input, str) else input[0].get('content','')\n",
    "        # 生成评价模板\n",
    "        prompt, _ = evaluation_prompt.render(\n",
    "            criteria=self.prompt_metrics,\n",
    "            steps=self.prompt_steps,\n",
    "            prompt=input_content,\n",
    "            response=output,\n",
    "            max_score=str(self.prompt_max_score),\n",
    "            metric_name=self.metric_name,\n",
    "        )\n",
    "        # 调用模型获得评分\n",
    "        msg = qianfan.Messages()\n",
    "        msg.append(prompt)\n",
    "        resp = self.model.do(\n",
    "            messages=msg,\n",
    "            temperature=0.1,\n",
    "            top_p=1,\n",
    "        )\n",
    "        # print(f'{self.metric_name}|{input[0]}|{output}|{resp[\"result\"].strip()}')\n",
    "        return {self.metric_name: resp['result'].strip()}\n"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-01-25T07:41:11.199850Z",
     "start_time": "2024-01-25T07:41:11.177983Z"
    }
   },
   "execution_count": 10
  },
  {
   "cell_type": "code",
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[WARNING] [01-25 15:41:11] model.py:375 [t:8678109824]: service type should be specified before exec\n"
     ]
    }
   ],
   "source": [
    "import qianfan\n",
    "from qianfan.model import Service\n",
    "\n",
    "eb_turbo_service = Service(model=\"ERNIE-Bot-turbo\")  # 加载生成回答的服务\n",
    "chat_comp = qianfan.ChatCompletion(model=\"ERNIE-Bot-4\")  # 实例化用于裁判的模型"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-01-25T07:41:11.200864Z",
     "start_time": "2024-01-25T07:41:11.192683Z"
    }
   },
   "execution_count": 11
  },
  {
   "cell_type": "markdown",
   "source": [
    "附加本地评估器，供EvaluationManager调用"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "local_evaluators = []\n",
    "for eval_type, (criteria, steps, max_score) in evaluation_metrics.items():\n",
    "    local_evaluator = QianfanLocalEvaluator(\n",
    "        prompt_metrics=criteria,\n",
    "        prompt_steps=steps,\n",
    "        prompt_max_score=max_score,\n",
    "        model=chat_comp,\n",
    "        metric_name=eval_type\n",
    "    )\n",
    "    local_evaluators.append(local_evaluator)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-01-25T07:41:11.211032Z",
     "start_time": "2024-01-25T07:41:11.201647Z"
    }
   },
   "execution_count": 12
  },
  {
   "cell_type": "code",
   "outputs": [],
   "source": [
    "from qianfan.evaluation import EvaluationManager\n",
    "\n",
    "em = EvaluationManager(local_evaluators=local_evaluators)\n",
    "result = em.eval(\n",
    "    [eb_turbo_service], ds,\n",
    "    #  is_chat_service=False # 默认为True，此时input为List[Dict[str, Any]]，否则input为str,\n",
    ")"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-01-25T07:41:27.946291Z",
     "start_time": "2024-01-25T07:41:11.205148Z"
    }
   },
   "execution_count": 13
  },
  {
   "cell_type": "code",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'input_chats': [{'content': '新华社受权于18日全文播发修改后的《中华人民共和国立法法》，修改后的立法法分为“总则”“法律”“行政法规”“地方性法规、自治条例和单行条例、规章”“适用与备案审查”“附则”等6章，共计105条。', 'role': 'user'}], 'expected_output': '修改后的立法法全文公布', 'model_content': [{'Coherence': '8', 'Consistency': '10', 'Fluency': '3', 'Relevance': '8', 'content': '是的，您说的是正确的。新华社于2023年6月18日受权播发了修改后的《中华人民共和国立法法》。这次修改后的立法法分为“总则”“法律”“行政法规”“地方性法规、自治条例和单行条例、规章”“适用与备案审查”“附则”等6章，共计105条。\\n\\n这次修改后的立法法的发布，是为了更好地规范立法活动，加强宪法实施，推进科学立法、民主立法、依法立法，提高立法质量和效率，不断完善以宪法为核心的中国特色社会主义法律体系。\\n\\n同时，这次修改后的立法法也加强了对国家立法和地方立法的指导，明确了立法原则、立法程序、法律解释、备案审查等方面的规定，为各级人大及其常委会依法行使立法权提供了更加明确的法律依据和程序保障。\\n\\n总之，这次修改后的立法法的发布，对于推进全面依法治国、建设社会主义法治国家具有重要的意义。', 'llm_tag': 'None_None_ERNIE-Bot-turbo'}]}\n",
      "{'input_chats': [{'content': '一辆小轿车，一名女司机，竟造成9死24伤。日前，深圳市交警局对事故进行通报：从目前证据看，事故系司机超速行驶且操作不当导致。目前24名伤员已有6名治愈出院，其余正接受治疗，预计事故赔偿费或超一千万元。', 'role': 'user'}], 'expected_output': '深圳机场9死24伤续：司机全责赔偿或超千万', 'model_content': [{'Coherence': '8', 'Consistency': '8', 'Fluency': '2', 'Relevance': '8', 'content': '这起事故是非常悲惨的，对受害者和家人造成了巨大的痛苦和损失。对于事故的原因，交警局的通报已经给出了明确的解释，即司机超速行驶且操作不当。这些因素导致了车辆失控，从而导致了这起惨剧。\\n\\n在这种情况下，我们应该深刻反思交通安全的重要性，加强交通安全教育和宣传，提高公众的交通安全意识，避免类似的悲剧再次发生。同时，对于事故赔偿，我们应该根据法律法规进行，确保受害者和家人的权益得到充分保障。\\n\\n总之，这起事故再次提醒我们，交通安全是每个人都应该关注的问题，我们应该时刻保持警惕，遵守交通规则，避免超速行驶等危险行为，共同营造一个安全的交通环境。', 'llm_tag': 'None_None_ERNIE-Bot-turbo'}]}\n",
      "{'input_chats': [{'content': '1月18日，习近平总书记对政法工作作出重要指示：2014年，政法战线各项工作特别是改革工作取得新成效。新形势下，希望全国政法机关主动适应新形势，为公正司法和提高执法司法公信力提供有力制度保障。', 'role': 'user'}], 'expected_output': '孟建柱：主动适应形势新变化提高政法机关服务大局的能力', 'model_content': [{'Coherence': '5', 'Consistency': '3', 'Fluency': '2', 'Relevance': '5', 'content': '关于这个问题，您可以参阅相关内容网站，您也可以问我一些其他问题，我会尽力为您解答。', 'llm_tag': 'None_None_ERNIE-Bot-turbo'}]}\n"
     ]
    }
   ],
   "source": [
    "result_dataset = result.result_dataset\n",
    "result_df = {\"News id\":[],\"Evaluation Type\":[],\"Score\":[]}\n",
    "for i, resp in enumerate(result_dataset.list()):\n",
    "    for metric_name in evaluation_metrics.keys():\n",
    "        result_df[\"News id\"].append(i)\n",
    "        # 添加评估类型及对应得分，因为此处只用了一个大模型，所以只取第一个结果\n",
    "        result_df[\"Evaluation Type\"].append(metric_name)\n",
    "        result_df[\"Score\"].append(resp[\"model_content\"][0][metric_name])\n",
    "    print(resp)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-01-25T07:41:27.953557Z",
     "start_time": "2024-01-25T07:41:27.949460Z"
    }
   },
   "execution_count": 14
  },
  {
   "cell_type": "code",
   "outputs": [
    {
     "data": {
      "text/plain": "Evaluation Type  Coherence  Consistency  Fluency  Relevance  Mean_Score\nNews id                                                                \n0                        8           10        3          8        7.25\n1                        8            8        2          8        6.50\n2                        5            3        2          5        3.75",
      "text/html": "<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th>Evaluation Type</th>\n      <th>Coherence</th>\n      <th>Consistency</th>\n      <th>Fluency</th>\n      <th>Relevance</th>\n      <th>Mean_Score</th>\n    </tr>\n    <tr>\n      <th>News id</th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>0</th>\n      <td>8</td>\n      <td>10</td>\n      <td>3</td>\n      <td>8</td>\n      <td>7.25</td>\n    </tr>\n    <tr>\n      <th>1</th>\n      <td>8</td>\n      <td>8</td>\n      <td>2</td>\n      <td>8</td>\n      <td>6.50</td>\n    </tr>\n    <tr>\n      <th>2</th>\n      <td>5</td>\n      <td>3</td>\n      <td>2</td>\n      <td>5</td>\n      <td>3.75</td>\n    </tr>\n  </tbody>\n</table>\n</div>"
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "pivot_df = pd.DataFrame(result_df, index=None).pivot(\n",
    "    columns=\"Evaluation Type\", index=\"News id\", values=\"Score\"\n",
    ").astype(int)\n",
    "pivot_df['Mean_Score'] = pivot_df.mean(axis=1)\n",
    "display(pivot_df)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-01-25T07:41:27.976189Z",
     "start_time": "2024-01-25T07:41:27.954714Z"
    }
   },
   "execution_count": 15
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 方法二：封装评估函数\n",
    "\n",
    "除了上述进行本地评估的方法，也可以用ChatCompletion模型构造评估函数，从头开始制作评估的流程"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-01-25T07:41:27.976829Z",
     "start_time": "2024-01-25T07:41:27.970762Z"
    }
   },
   "outputs": [],
   "source": [
    "import qianfan\n",
    "\n",
    "def get_geval_score(chat_comp, evaluation_prompt, **kwargs):\n",
    "    prompt, _ = evaluation_prompt.render(**kwargs)\n",
    "    msg = qianfan.Messages()\n",
    "    msg.append(prompt, role='user')\n",
    "    resp = chat_comp.do(\n",
    "        messages=msg,\n",
    "        temperature=0.1,\n",
    "        top_p=1,\n",
    "    )\n",
    "    return resp['result']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "最后遍历数据集进行评估"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-01-25T07:41:42.668542Z",
     "start_time": "2024-01-25T07:41:27.975559Z"
    }
   },
   "outputs": [],
   "source": [
    "chat_comp = qianfan.ChatCompletion(model=\"ERNIE-Bot-4\")\n",
    "result = {\"Evaluation Type\": [], \"News id\": [], \"Score\": []}\n",
    "for eval_type, (criteria, steps, max_score) in evaluation_metrics.items():\n",
    "    for ind, data in enumerate(ds):\n",
    "        result[\"Evaluation Type\"].append(eval_type)\n",
    "        result[\"News id\"].append(ind)\n",
    "        evaluate = get_geval_score(\n",
    "            chat_comp,\n",
    "            evaluation_prompt,\n",
    "            criteria=criteria,\n",
    "            steps=steps,\n",
    "            prompt=data['prompt'],\n",
    "            response=data['response'][0][0],\n",
    "            max_score=str(max_score),\n",
    "            metric_name=eval_type\n",
    "        )\n",
    "        score = int(evaluate.strip())\n",
    "        result[\"Score\"].append(evaluate)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "查看最终结果"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-01-25T07:41:42.705224Z",
     "start_time": "2024-01-25T07:41:42.682140Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": "Evaluation Type  Coherence  Consistency  Fluency  Relevance  Mean_Score\nNews id                                                                \n0                        8           10        3          9        7.50\n1                        8            8        2          8        6.50\n2                        5            3        2          5        3.75",
      "text/html": "<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th>Evaluation Type</th>\n      <th>Coherence</th>\n      <th>Consistency</th>\n      <th>Fluency</th>\n      <th>Relevance</th>\n      <th>Mean_Score</th>\n    </tr>\n    <tr>\n      <th>News id</th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n      <th></th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>0</th>\n      <td>8</td>\n      <td>10</td>\n      <td>3</td>\n      <td>9</td>\n      <td>7.50</td>\n    </tr>\n    <tr>\n      <th>1</th>\n      <td>8</td>\n      <td>8</td>\n      <td>2</td>\n      <td>8</td>\n      <td>6.50</td>\n    </tr>\n    <tr>\n      <th>2</th>\n      <td>5</td>\n      <td>3</td>\n      <td>2</td>\n      <td>5</td>\n      <td>3.75</td>\n    </tr>\n  </tbody>\n</table>\n</div>"
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "pivot_df = pd.DataFrame(result, index=None).pivot(\n",
    "    columns=\"Evaluation Type\", index=\"News id\", values=\"Score\"\n",
    ").astype(int)\n",
    "pivot_df['Mean_Score'] = pivot_df.mean(axis=1)\n",
    "display(pivot_df)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
